
[PyTorch] Backwards compatible single param checkpointing in GroupedLinear#2761

Merged
ksivaman merged 7 commits into NVIDIA:main from
ksivaman:backwards_compatible_single_param_checkpointing
Mar 16, 2026

Conversation

@ksivaman
Member

Description

The GroupedLinear module registers its weights either as a single parameter via GroupedTensor or as one parameter per expert. This PR adds checkpoint-loading compatibility across those two options.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Allow conversion of checkpoints from one param format to the other.
  • Add a checkpointing test to verify functionality.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@greptile-apps
Contributor

greptile-apps bot commented Mar 13, 2026

Greptile Summary

This PR adds backwards-compatible checkpoint loading to GroupedLinear, enabling state dicts saved in the per-GEMM format (weight0..N) to be loaded into a single_grouped_parameter=True model and vice versa. The implementation hooks into load_state_dict and _load_from_state_dict via a new _remap_grouped_weight_state_dict_keys helper that rewrites the relevant state dict keys before PyTorch's standard loading logic runs.

Key changes:

  • _remap_grouped_weight_state_dict_keys: translates between weight (grouped) and weight0..N (per-GEMM) key formats in-place, handling dequantization via QuantizedTensorStorage in both directions.
  • load_state_dict override: makes a shallow copy of the caller's dict (preserving _metadata) and remaps before delegating to super().
  • _load_from_state_dict override: ensures the remap also fires when GroupedLinear is loaded as a nested submodule (where load_state_dict is not the entry point).
  • Two new tests cover both conversion directions using float32 weights and strict loading.
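The core of the remap can be sketched in plain PyTorch. The following is a hypothetical, simplified stand-in — `remap_keys`, `single_grouped_parameter`, and `num_gemms` are illustrative names, and the sketch omits the real helper's QuantizedTensorStorage dequantization and GroupedTensor splitting paths:

```python
import torch

def remap_keys(state_dict, prefix, num_gemms, single_grouped_parameter):
    """Translate between a single stacked 'weight' entry and per-GEMM
    'weight0..N' entries, in-place (simplified sketch)."""
    grouped_key = f"{prefix}weight"
    per_gemm_keys = [f"{prefix}weight{i}" for i in range(num_gemms)]
    has_grouped = grouped_key in state_dict
    has_per_gemm = all(k in state_dict for k in per_gemm_keys)
    if single_grouped_parameter and has_per_gemm and not has_grouped:
        # Checkpoint is per-GEMM, module wants one grouped parameter: stack.
        state_dict[grouped_key] = torch.stack(
            [state_dict.pop(k) for k in per_gemm_keys], dim=0
        )
    elif not single_grouped_parameter and has_grouped and not has_per_gemm:
        # Checkpoint is grouped, module wants per-GEMM parameters: unbind.
        grouped = state_dict.pop(grouped_key)
        for i, w in enumerate(grouped.unbind(dim=0)):
            state_dict[f"{prefix}weight{i}"] = w

# Round trip: per-GEMM -> grouped -> per-GEMM.
sd = {f"weight{i}": torch.full((4, 4), float(i)) for i in range(3)}
remap_keys(sd, "", 3, single_grouped_parameter=True)
assert set(sd) == {"weight"} and sd["weight"].shape == (3, 4, 4)
remap_keys(sd, "", 3, single_grouped_parameter=False)
assert set(sd) == {"weight0", "weight1", "weight2"}
```

Because both directions only move tensors between keys, the remap is idempotent once the dict is already in the target format — the property the review relies on below.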

Issues found:

  • assign=True is silently broken for the multi-to-single case: _remap_grouped_weight_state_dict_keys inserts a plain torch.stack tensor under "weight", and PyTorch's assign=True path replaces the GroupedTensor parameter with this plain tensor rather than calling copy_, breaking subsequent forward passes.
  • _load_from_state_dict mutates the shared state dict passed by the parent module, which is a non-obvious side effect.
  • No validation that the number of tensors recovered from a grouped checkpoint matches self.num_gemms, leading to cryptic "unexpected key" errors on misconfigured loads.

Confidence Score: 3/5

  • Safe to merge for the default assign=False path, but assign=True will silently corrupt the module's weight parameter type during cross-format loading.
  • The core compatibility logic is well-structured and the two tests validate the happy path. However, the assign=True bug is a real correctness issue that would be invisible to the user (no exception, but a broken GroupedTensor replaced by a plain tensor). The in-place mutation of the shared state dict in _load_from_state_dict is also a latent maintenance hazard. These issues prevent a higher confidence score.
  • transformer_engine/pytorch/module/grouped_linear.py — specifically the load_state_dict and _load_from_state_dict overrides, and the assign=True interaction with _remap_grouped_weight_state_dict_keys.

Important Files Changed

Filename Overview
transformer_engine/pytorch/module/grouped_linear.py Adds _remap_grouped_weight_state_dict_keys, load_state_dict, and _load_from_state_dict overrides to support cross-format checkpoint loading between single-grouped-parameter and per-GEMM formats. Key concerns: assign=True silently replaces GroupedTensor with a plain tensor when doing multi-to-single conversion; in-place mutation of the shared state dict inside _load_from_state_dict is a side effect that may surprise callers; no GEMM-count validation on split.
tests/pytorch/test_grouped_tensor.py Adds two round-trip checkpoint tests covering multi-to-single and single-to-multi parameter format conversions. Tests correctly rely on split_into_quantized_tensors() returning views (documented behavior). Only covers assign=False (default) path; assign=True is not exercised.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["User calls GroupedLinear.load_state_dict(state_dict, strict, assign)"] --> B["Make shallow copy: state_dict_copy"]
    B --> C["_remap_grouped_weight_state_dict_keys(state_dict_copy, prefix='')"]
    C --> D{single_grouped_parameter?}
    D -- "True\n(want 'weight')" --> E{has weight0..N\nbut no weight?}
    E -- Yes --> F["Stack weight0..N → plain torch.Tensor\nInsert as 'weight'"]
    E -- No --> G["Drop redundant per-GEMM keys"]
    D -- "False\n(want weight0..N)" --> H{has 'weight'\nbut no weight0..N?}
    H -- Yes --> I{is GroupedTensor?}
    I -- Yes --> J["split_into_quantized_tensors()\nor use .quantized_tensors"]
    I -- No --> K["unbind(dim=0)"]
    J --> L["Insert weight0..N as plain tensors"]
    K --> L
    H -- No --> M["Drop redundant 'weight' key"]
    F --> N["super().load_state_dict(state_dict_copy)"]
    G --> N
    L --> N
    M --> N
    N --> O["PyTorch recursive loading\n→ calls _load_from_state_dict"]
    O --> P["_remap_grouped_weight_state_dict_keys again\n(idempotent, no-op)"]
    P --> Q["super()._load_from_state_dict()"]
    Q --> R{assign=True?}
    R -- "No (default)" --> S["param.copy_(state_dict_value)\nGroupedTensor __torch_dispatch__ handles copy"]
    R -- "Yes ⚠️" --> T["setattr(module, 'weight', plain_tensor)\nGroupedTensor replaced by plain tensor — BUG"]

Last reviewed commit: ebd23b9

Comment on lines +901 to +908
def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False):
    """Load state dict with grouped-weight format compatibility."""
    state_dict_copy = state_dict.copy()
    metadata = getattr(state_dict, "_metadata", None)
    if metadata is not None:
        state_dict_copy._metadata = metadata
    self._remap_grouped_weight_state_dict_keys(state_dict_copy, prefix="")
    return super().load_state_dict(state_dict_copy, strict=strict, assign=assign)
Contributor

Double remapping of weight keys

_remap_grouped_weight_state_dict_keys is applied twice whenever GroupedLinear.load_state_dict is the entry point:

  1. Explicitly in load_state_dict (line 907).
  2. Again inside GroupedLinear._load_from_state_dict (line 914), which PyTorch's super().load_state_dict() invokes internally as part of its recursive loading loop.

The second call is idempotent — after the first remap the state dict is already in the expected format, so the second remap is a no-op — but the redundancy is a maintenance hazard: a future change that makes the remap non-idempotent could silently introduce data corruption (e.g. double-stacking weights).

A straightforward fix is to skip the remap inside load_state_dict and let _load_from_state_dict handle it exclusively (which already covers the nested-module case). The copy is still needed to avoid mutating the caller's dict, so it should be preserved:

def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False):
    """Load state dict with grouped-weight format compatibility."""
    state_dict_copy = state_dict.copy()
    metadata = getattr(state_dict, "_metadata", None)
    if metadata is not None:
        state_dict_copy._metadata = metadata
    # Key remapping is performed in _load_from_state_dict which PyTorch
    # calls internally; no need to remap again here.
    return super().load_state_dict(state_dict_copy, strict=strict, assign=assign)

This keeps the copy (protecting the caller's dict) and relies on _load_from_state_dict for the single, canonical remap path in all cases.

torch.save(src.state_dict(), ckpt_path)
del src

src_state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False)
Contributor

weights_only=False enables arbitrary pickle execution

torch.load(..., weights_only=False) deserialises the file using Python's pickle module, which executes arbitrary code embedded in the file. PyTorch 2.x already emits a FutureWarning for this pattern and the default will flip to True in a future release.

For the multi-to-single test (test_grouped_linear_load_state_dict_multi_to_single_param) the source model uses single_grouped_parameter=False, so all saved tensors are plain torch.Tensor objects — weights_only=True should work fine there.

Suggested change
src_state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False)
src_state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)

The same concern applies to line 540 in test_grouped_linear_load_state_dict_single_to_multi_param. For that test the saved weight is a GroupedTensor subclass, which may require weights_only=False to deserialise; if so, the incompatibility should be documented with an inline comment explaining why weights_only=True cannot be used.
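The plain-tensor case can be verified in isolation; a minimal sketch (the file path and keys are illustrative):

```python
import os
import tempfile

import torch

# Plain torch.Tensor entries round-trip safely under weights_only=True,
# which avoids pickle's arbitrary-code-execution surface entirely.
state = {"weight0": torch.randn(4, 4), "weight1": torch.randn(4, 4)}
path = os.path.join(tempfile.mkdtemp(), "grouped_linear_per_gemm.pt")
torch.save(state, path)

loaded = torch.load(path, map_location="cpu", weights_only=True)
assert set(loaded) == set(state)
assert torch.equal(loaded["weight0"], state["weight0"])
```

For checkpoints containing tensor subclasses such as GroupedTensor, recent PyTorch releases offer torch.serialization.add_safe_globals to allowlist specific classes under weights_only=True; where that is not viable, the weights_only=False call should carry the inline comment the review asks for.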

@ksivaman
Member Author

/te-ci pytorch

expected_weights = [getattr(src, f"weight{i}").detach().clone() for i in range(num_gemms)]
ckpt_path = tmp_path / "grouped_linear_per_gemm.pt"
torch.save(src.state_dict(), ckpt_path)
del src
Collaborator

Should we also add a test case for quantized_model_init(mxfp8)? Shouldn't be a blocker for this PR though.

Comment on lines +901 to +908
def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False):
    """Load state dict with grouped-weight format compatibility."""
    state_dict_copy = state_dict.copy()
    metadata = getattr(state_dict, "_metadata", None)
    if metadata is not None:
        state_dict_copy._metadata = metadata
    self._remap_grouped_weight_state_dict_keys(state_dict_copy, prefix="")
    return super().load_state_dict(state_dict_copy, strict=strict, assign=assign)
Contributor

assign=True replaces GroupedTensor with a plain tensor

When assign=True is passed to load_state_dict and the multi-to-single conversion is active (single_grouped_parameter=True, checkpoint has weight0..N), _remap_grouped_weight_state_dict_keys writes a plain torch.Tensor (from torch.stack) into state_dict_copy["weight"]. PyTorch's assign=True path then calls setattr(module, "weight", plain_tensor) instead of param.copy_(plain_tensor), so the GroupedTensor parameter is silently replaced by a plain tensor. Any subsequent forward pass that calls self.weight.split_into_quantized_tensors() or relies on the GroupedTensor.__torch_dispatch__ mechanism will crash or silently compute incorrect results.

A fix is to either document that assign=True is unsupported for cross-format loading, or reconstruct a proper GroupedTensor inside the remap helper when the target format is single_grouped_parameter=True:

def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False):
    """Load state dict with grouped-weight format compatibility."""
    if assign:
        warnings.warn(
            "GroupedLinear.load_state_dict with assign=True does not support "
            "cross-format checkpoint loading. Use assign=False (default).",
            UserWarning,
        )
    state_dict_copy = state_dict.copy()
    metadata = getattr(state_dict, "_metadata", None)
    if metadata is not None:
        state_dict_copy._metadata = metadata
    self._remap_grouped_weight_state_dict_keys(state_dict_copy, prefix="")
    return super().load_state_dict(state_dict_copy, strict=strict, assign=assign)

Comment on lines +874 to +896
if not has_per_gemm_weights and has_grouped_weight:
    grouped_weight = state_dict.pop(grouped_weight_key)
    if hasattr(grouped_weight, "split_into_quantized_tensors"):
        grouped_members = grouped_weight.quantized_tensors
        if grouped_members is None:
            grouped_members = grouped_weight.split_into_quantized_tensors()
        per_gemm_weights = [
            (
                weight.dequantize()
                if isinstance(weight, QuantizedTensorStorage)
                else weight
            )
            for weight in grouped_members
        ]
    else:
        grouped_weight = (
            grouped_weight.dequantize()
            if isinstance(grouped_weight, QuantizedTensorStorage)
            else grouped_weight
        )
        per_gemm_weights = list(grouped_weight.unbind(dim=0))
    for i, weight in enumerate(per_gemm_weights):
        state_dict[f"{prefix}weight{i}"] = weight
Contributor

No validation of GEMM count after splitting

When splitting a grouped checkpoint into per-GEMM weights, neither the split_into_quantized_tensors() path nor the unbind(dim=0) path validates that the number of recovered tensors equals self.num_gemms. If the checkpoint was created with a different number of GEMMs (e.g., num_gemms=5 saved, num_gemms=3 loaded), the remap will silently inject weight0..4 into the state dict. With strict=True, PyTorch will then report weight3 and weight4 as unexpected keys — but the diagnostic message gives no hint that the root cause is a GEMM-count mismatch.

Adding an explicit early check here improves debuggability:

if hasattr(grouped_weight, "split_into_quantized_tensors"):
    grouped_members = grouped_weight.quantized_tensors
    if grouped_members is None:
        grouped_members = grouped_weight.split_into_quantized_tensors()
    if len(grouped_members) != self.num_gemms:
        raise ValueError(
            f"Checkpoint grouped weight contains {len(grouped_members)} GEMMs "
            f"but this module was configured with num_gemms={self.num_gemms}."
        )
    ...
else:
    per_gemm_weights = list(grouped_weight.unbind(dim=0))
    if len(per_gemm_weights) != self.num_gemms:
        raise ValueError(
            f"Checkpoint stacked weight has {len(per_gemm_weights)} slices along dim=0 "
            f"but this module was configured with num_gemms={self.num_gemms}."
        )

Comment on lines +910 to +918
def _load_from_state_dict(
    self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
    """Load state, including compatibility across grouped-weight checkpoint formats."""
    self._remap_grouped_weight_state_dict_keys(state_dict, prefix)

    super()._load_from_state_dict(
        state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    )
Contributor

_load_from_state_dict mutates the shared state dict in-place

When GroupedLinear is used as a submodule, PyTorch passes the same state_dict object (with the full prefix tree) to every module's _load_from_state_dict. This override calls _remap_grouped_weight_state_dict_keys, which modifies that shared dict in-place — popping old keys and inserting new ones (e.g. swapping "parent.grouped.weight" out for "parent.grouped.weight0..N").

While the key operations are scoped by prefix and don't touch other modules' keys, the mutation is a side-effect that:

  1. Permanently alters the caller's state dict after the fact (the user may not expect their dict to be modified when loading a submodule).
  2. Interacts unexpectedly with the unexpected_keys accounting in PyTorch's base _load_from_state_dict if the newly injected keys are not all consumed.

A defensive pattern is to work on a shallow copy of only the module's relevant key-space, similar to what load_state_dict already does at the top level. At minimum, adding a comment here that the mutation is intentional and scoped would reduce the maintenance burden.
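That defensive pattern can be sketched generically — `remap_scoped` and the toy remap below are illustrative, not the PR's code. The idea is to run the remap on a scratch copy of only this module's key-space, then splice the result back, so the shared dict is mutated in one clearly scoped step:

```python
def remap_scoped(state_dict, prefix, remap_fn):
    """Remap only keys under `prefix`, working on a scratch copy so the
    shared dict is mutated in a single, well-scoped splice."""
    scoped = {k: v for k, v in state_dict.items() if k.startswith(prefix)}
    remap_fn(scoped, prefix)
    for k in [k for k in state_dict if k.startswith(prefix)]:
        del state_dict[k]
    state_dict.update(scoped)

# Toy remap standing in for the real helper: merge per-GEMM keys
# into one grouped key.
def toy_remap(sd, prefix):
    sd[prefix + "weight"] = [sd.pop(prefix + "weight0"), sd.pop(prefix + "weight1")]

shared = {"other.bias": 1, "m.weight0": 10, "m.weight1": 20}
remap_scoped(shared, "m.", toy_remap)
# Other modules' entries are untouched; only "m.*" keys were rewritten.
assert shared == {"other.bias": 1, "m.weight": [10, 20]}
```

The mutation still happens (PyTorch's recursive loader must see the remapped keys), but it is confined to the module's own prefix and a single update site.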

    state_dict[grouped_weight_key] = torch.stack(per_gemm_weights, dim=0)
elif has_grouped_weight:
    # Drop any redundant per-GEMM keys to avoid strict-load unexpected-key errors.
    for key in per_gemm_weight_keys:
Collaborator

We might need this even for TE sequential checkpointing, right? Maybe putting it in utils and using it in both places makes sense to avoid code duplication.

@ksivaman ksivaman merged commit 4017565 into NVIDIA:main Mar 16, 2026
20 of 24 checks passed
@ksivaman ksivaman deleted the backwards_compatible_single_param_checkpointing branch March 16, 2026 19:09
KshitijLakhani pushed a commit that referenced this pull request Mar 20, 2026
…Linear` (#2761)

* Load multi-param checkpoint from single-param config in GroupedLinear

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Multi-param to single param case

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Multi-param to single param case

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Better varnames

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
vthumbe1503 pushed a commit to ksivaman/TransformerEngine-1 that referenced this pull request Apr 1, 2026
…Linear` (NVIDIA#2761)

* Load multi-param checkpoint from single-param config in GroupedLinear

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Multi-param to single param case

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Multi-param to single param case

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Better varnames

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
